{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8bc2d278-f559-4a84-b00c-d9499759b114",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:42:01.435430Z",
     "iopub.status.busy": "2024-12-19T14:42:01.434784Z",
     "iopub.status.idle": "2024-12-19T14:42:01.444346Z",
     "shell.execute_reply": "2024-12-19T14:42:01.442151Z",
     "shell.execute_reply.started": "2024-12-19T14:42:01.435381Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c252812-a675-4463-a4b7-066387347af9",
   "metadata": {},
   "source": [
    "### softmax is translation invariant"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12f4c919-5c16-4649-ac0a-c262b0222c53",
   "metadata": {},
   "source": [
    "$$\n",
    "\\sigma(\\mathbf z)_i=\\frac{\\exp(z_i)}{\\sum_{j=1}^K\\exp(z_j)}=\\frac{\\exp(z_i-a)}{\\sum_{j=1}^K\\exp(z_j-a)}=\\frac{\\exp(z_i+a)}{\\sum_{j=1}^K\\exp(z_j+a)}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "600fcf95-26ea-4d8e-88f5-a5d631520eeb",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:42:13.266866Z",
     "iopub.status.busy": "2024-12-19T14:42:13.266269Z",
     "iopub.status.idle": "2024-12-19T14:42:13.281460Z",
     "shell.execute_reply": "2024-12-19T14:42:13.279200Z",
     "shell.execute_reply.started": "2024-12-19T14:42:13.266820Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2485,  0.1528,  1.5404, -0.4477, -0.9106],\n",
       "        [ 0.3894,  0.4204,  0.0760, -0.2010, -0.5533],\n",
       "        [-0.7372, -0.8981, -0.6507,  0.0449,  0.9385],\n",
       "        [ 0.4495, -0.3065, -0.2285, -0.8812, -0.5640],\n",
       "        [ 1.4365, -0.7236,  0.1877,  0.3703, -3.7927]])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(5, 5)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "57476d09-e383-4900-919f-ba1d96b55826",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:42:22.113815Z",
     "iopub.status.busy": "2024-12-19T14:42:22.113178Z",
     "iopub.status.idle": "2024-12-19T14:42:22.133078Z",
     "shell.execute_reply": "2024-12-19T14:42:22.131281Z",
     "shell.execute_reply.started": "2024-12-19T14:42:22.113768Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1572, 0.1429, 0.5722, 0.0784, 0.0493],\n",
       "        [0.2698, 0.2783, 0.1972, 0.1495, 0.1051],\n",
       "        [0.0955, 0.0813, 0.1041, 0.2088, 0.5103],\n",
       "        [0.3840, 0.1803, 0.1949, 0.1015, 0.1394],\n",
       "        [0.5708, 0.0658, 0.1637, 0.1966, 0.0031]])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.softmax(x, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4668a121-9a4d-4ba4-ab82-b41458c64bc5",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:42:35.671987Z",
     "iopub.status.busy": "2024-12-19T14:42:35.671325Z",
     "iopub.status.idle": "2024-12-19T14:42:35.684635Z",
     "shell.execute_reply": "2024-12-19T14:42:35.682890Z",
     "shell.execute_reply.started": "2024-12-19T14:42:35.671939Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1572, 0.1429, 0.5722, 0.0784, 0.0493],\n",
       "        [0.2698, 0.2783, 0.1972, 0.1495, 0.1051],\n",
       "        [0.0955, 0.0813, 0.1041, 0.2088, 0.5103],\n",
       "        [0.3840, 0.1803, 0.1949, 0.1015, 0.1394],\n",
       "        [0.5708, 0.0658, 0.1637, 0.1966, 0.0031]])"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.softmax(x - 1, dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a2c99cac-d20f-4f3f-9222-1a2da9efa99a",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:42:48.142807Z",
     "iopub.status.busy": "2024-12-19T14:42:48.142186Z",
     "iopub.status.idle": "2024-12-19T14:42:48.153530Z",
     "shell.execute_reply": "2024-12-19T14:42:48.152073Z",
     "shell.execute_reply.started": "2024-12-19T14:42:48.142762Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.1572, 0.1429, 0.5722, 0.0784, 0.0493],\n",
       "        [0.2698, 0.2783, 0.1972, 0.1495, 0.1051],\n",
       "        [0.0955, 0.0813, 0.1041, 0.2088, 0.5103],\n",
       "        [0.3840, 0.1803, 0.1949, 0.1015, 0.1394],\n",
       "        [0.5708, 0.0658, 0.1637, 0.1966, 0.0031]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.softmax(x + 5, dim=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11136cbd-3c65-41cc-84c2-78c4da87da25",
   "metadata": {},
   "source": [
    "### hooked transformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c7a65179-e616-49b3-aac3-98bedf24a55d",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:39:06.310623Z",
     "iopub.status.busy": "2024-12-19T14:39:06.310022Z",
     "iopub.status.idle": "2024-12-19T14:39:06.320467Z",
     "shell.execute_reply": "2024-12-19T14:39:06.318244Z",
     "shell.execute_reply.started": "2024-12-19T14:39:06.310576Z"
    }
   },
   "outputs": [],
   "source": [
    "from transformer_lens import HookedTransformer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78e010de-3976-47c6-8189-71bc1e3d85aa",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:39:21.653312Z",
     "iopub.status.busy": "2024-12-19T14:39:21.652664Z",
     "iopub.status.idle": "2024-12-19T14:39:26.666633Z",
     "shell.execute_reply": "2024-12-19T14:39:26.665411Z",
     "shell.execute_reply.started": "2024-12-19T14:39:21.653265Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained model gpt2 into HookedTransformer\n"
     ]
    }
   ],
   "source": [
    "# center_unembed = True\n",
    "model = HookedTransformer.from_pretrained(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "93533a24-2965-4546-8afa-7bd50e1bcb50",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:39:59.677856Z",
     "iopub.status.busy": "2024-12-19T14:39:59.677548Z",
     "iopub.status.idle": "2024-12-19T14:39:59.685817Z",
     "shell.execute_reply": "2024-12-19T14:39:59.683625Z",
     "shell.execute_reply.started": "2024-12-19T14:39:59.677836Z"
    }
   },
   "outputs": [],
   "source": [
    "W_U = next(model.unembed.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "050a21b6-d1bf-4a08-8fba-f0c27245cc60",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:40:05.520904Z",
     "iopub.status.busy": "2024-12-19T14:40:05.520257Z",
     "iopub.status.idle": "2024-12-19T14:40:05.532842Z",
     "shell.execute_reply": "2024-12-19T14:40:05.530700Z",
     "shell.execute_reply.started": "2024-12-19T14:40:05.520856Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([768, 50257])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_U.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "03ec5cbf-2599-4b86-adeb-979ff1550072",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-12-19T14:40:38.207854Z",
     "iopub.status.busy": "2024-12-19T14:40:38.207227Z",
     "iopub.status.idle": "2024-12-19T14:40:38.255882Z",
     "shell.execute_reply": "2024-12-19T14:40:38.253696Z",
     "shell.execute_reply.started": "2024-12-19T14:40:38.207808Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-7.4006e-10, -2.8843e-09,  1.2809e-09,  5.2184e-10, -2.2771e-10,\n",
       "         2.4289e-09, -7.1122e-08, -2.3151e-09, -6.8313e-10,  5.3607e-10,\n",
       "        -7.2109e-10, -1.4801e-09,  1.8976e-10, -2.3056e-09,  9.8960e-09,\n",
       "        -1.4042e-09, -1.3093e-09, -5.8825e-10,  5.6928e-10,  2.1633e-09,\n",
       "        -3.0362e-10,  3.0362e-10, -2.1253e-09, -1.5181e-10,  5.1140e-09,\n",
       "         9.2033e-10, -2.0684e-09,  5.9774e-10,  6.1292e-09,  1.7837e-09,\n",
       "         4.5542e-10,  6.4803e-09, -2.1443e-09,  7.2109e-10,  2.7041e-10,\n",
       "        -5.7687e-09, -2.9189e-07, -7.5904e-10, -2.9413e-10, -1.5181e-10,\n",
       "        -7.5904e-11, -4.7440e-10,  7.5904e-10, -1.1006e-09,  1.1006e-09,\n",
       "        -1.8976e-10, -5.8825e-10,  2.2771e-10,  1.2334e-10,  9.2413e-09,\n",
       "        -1.1955e-09,  1.2904e-09,  1.9735e-09,  4.9337e-10, -3.0551e-09,\n",
       "         1.6395e-08, -3.1121e-09, -3.4916e-09,  6.0723e-10,  6.4898e-09,\n",
       "        -3.0362e-10,  4.5922e-09,  1.3283e-09,  7.9699e-09, -6.3510e-10,\n",
       "         1.7837e-09,  2.6566e-10, -6.0344e-09,  3.5884e-08,  4.3645e-10,\n",
       "        -4.5542e-10,  9.4880e-10, -1.4232e-09, -1.6699e-09,  2.0874e-10,\n",
       "        -4.0798e-09,  1.2334e-09,  3.1880e-09,  2.1633e-09,  2.3530e-09,\n",
       "        -4.2506e-09,  2.8464e-09, -1.0627e-09, -1.0627e-09, -9.6872e-09,\n",
       "         3.3587e-09,  2.5466e-08, -6.7424e-10,  2.2676e-09, -1.2145e-09,\n",
       "        -3.7952e-11,  2.9318e-09,  9.1085e-10, -1.4422e-09,  3.0172e-09,\n",
       "        -7.5904e-10, -2.3530e-09,  7.5904e-10, -8.0648e-10,  1.1006e-09,\n",
       "         7.5904e-10,  1.2524e-09, -5.2677e-08, -7.4006e-09,  1.5181e-10,\n",
       "         7.5904e-10, -7.5904e-11,  5.7687e-09,  3.7952e-10,  1.9735e-09,\n",
       "         7.5904e-10,  1.7458e-09, -3.3967e-09, -1.6699e-09, -1.9735e-09,\n",
       "         4.8389e-10, -6.6795e-09,  2.6566e-10,  1.3852e-09, -1.2600e-08,\n",
       "        -3.4916e-09, -2.1822e-09, -1.8976e-09,  1.1006e-09,  1.0627e-09,\n",
       "         1.9166e-09,  2.5807e-09, -3.0836e-09,  1.2866e-08, -1.1575e-09,\n",
       "         6.8313e-10, -8.7289e-10,  2.0779e-09, -3.0362e-10,  1.2524e-09,\n",
       "        -1.4801e-09, -1.8217e-09,  4.3645e-10, -1.7434e-10,  1.3473e-08,\n",
       "         1.3188e-09,  2.7705e-09,  1.9507e-08,  1.1765e-09,  1.0816e-09,\n",
       "         4.5922e-09,  9.8675e-10,  1.2145e-09, -2.2771e-10, -1.4611e-09,\n",
       "        -3.4916e-09, -1.8217e-09, -2.6756e-09,  6.5562e-09, -1.4422e-09,\n",
       "         1.7078e-09, -2.7895e-09, -4.2506e-09,  2.5048e-09, -8.6341e-09,\n",
       "         2.0684e-08,  2.5618e-10,  2.2012e-09, -1.1386e-10, -7.9699e-10,\n",
       "         6.8313e-10,  4.2886e-09,  4.7971e-08, -2.1443e-09, -1.7610e-08,\n",
       "         7.4481e-09, -1.1006e-09,  3.2828e-09,  8.3494e-10,  7.7801e-09,\n",
       "         1.7268e-09, -3.0343e-08, -9.4880e-10, -9.4880e-11, -7.0591e-09,\n",
       "        -2.4858e-09,  3.0362e-10,  1.2809e-09, -1.4611e-09,  1.2145e-09,\n",
       "         4.9337e-10,  2.9223e-09,  9.4880e-10, -8.7289e-10,  4.0988e-09,\n",
       "        -2.5523e-09, -4.3834e-09,  1.1386e-10, -4.9907e-09, -8.5392e-11,\n",
       "        -5.4983e-09,  8.0648e-10,  4.9243e-09, -6.4518e-10,  5.5030e-09,\n",
       "        -3.8426e-09,  2.2202e-09,  6.8313e-10,  1.8976e-09, -5.6928e-11,\n",
       "         2.1633e-09,  1.9545e-09, -7.5904e-10,  2.7136e-09, -3.2449e-09,\n",
       "         3.2639e-09,  0.0000e+00, -7.7801e-10,  1.3245e-08, -5.6928e-10,\n",
       "         1.5181e-09, -5.3133e-09, -2.1822e-10, -3.6434e-09, -1.8976e-10,\n",
       "        -1.6130e-10,  9.8675e-10, -3.7952e-10, -1.8976e-09,  3.0172e-09,\n",
       "        -9.4880e-11, -6.8313e-10,  3.4157e-10,  1.0057e-09, -1.6699e-09,\n",
       "         3.9849e-10, -3.2069e-09,  4.3759e-08, -1.8596e-09,  1.2334e-10,\n",
       "        -1.2904e-09,  2.6471e-09, -4.2506e-09, -1.0247e-09,  9.4880e-10,\n",
       "        -6.4518e-10, -2.0115e-09,  3.1880e-09, -6.8313e-10,  2.4669e-10,\n",
       "         2.5048e-09,  2.0874e-09,  3.6054e-10, -1.8596e-09,  2.0874e-09,\n",
       "        -1.8103e-08,  4.2127e-09, -1.8976e-09, -1.0627e-09, -3.0362e-10,\n",
       "         5.1045e-09,  5.3892e-09, -8.7289e-10,  9.8675e-10,  5.6928e-11,\n",
       "         4.1747e-10,  1.7638e-08,  1.8312e-09,  2.2012e-09, -3.8901e-09,\n",
       "         3.0551e-09, -1.2954e-09, -1.0095e-08,  2.4574e-09,  1.5181e-09,\n",
       "        -5.3892e-09,  2.0077e-08,  1.1386e-10,  1.1101e-09, -2.9033e-09,\n",
       "         1.1006e-09, -3.8142e-09,  6.1672e-10,  1.3852e-09,  6.0723e-10,\n",
       "         0.0000e+00, -4.1747e-10,  2.0874e-10,  6.9926e-09,  1.4801e-09,\n",
       "         1.2239e-09, -4.9337e-10, -1.0247e-09,  1.5408e-08,  3.8863e-08,\n",
       "         2.1253e-09,  1.6130e-09,  2.9223e-09,  2.3530e-09, -3.5865e-09,\n",
       "        -9.8675e-10,  1.3663e-09, -1.5940e-09,  2.9982e-09,  7.9699e-10,\n",
       "        -2.2771e-10, -2.7515e-10, -9.2982e-10,  2.3151e-09, -1.9735e-09,\n",
       "        -6.0723e-09,  6.4518e-10, -2.1443e-09, -4.3721e-08,  5.5030e-10,\n",
       "         3.0741e-09,  6.5657e-09, -4.5542e-10,  3.7952e-10, -9.9548e-08,\n",
       "         1.1006e-09,  2.3530e-09, -5.6928e-10,  1.0057e-09, -1.1006e-09,\n",
       "         6.4063e-08, -3.0362e-10,  2.6282e-09, -6.6985e-09,  1.7932e-09,\n",
       "         4.2696e-10,  3.1861e-08,  1.0247e-09, -2.6566e-10, -1.1955e-09,\n",
       "         3.4916e-09,  5.3133e-10,  4.5542e-09, -7.5904e-11,  2.0115e-09,\n",
       "         3.0362e-10, -9.2982e-10, -3.5675e-09, -3.7952e-11, -1.9071e-09,\n",
       "         9.8675e-10,  1.8027e-09,  5.3133e-10,  2.2771e-10, -6.8313e-10,\n",
       "        -3.1121e-09,  5.0097e-09, -4.5163e-09,  8.3494e-10,  7.9699e-10,\n",
       "        -1.3473e-09,  1.6519e-08, -2.7420e-09,  5.7023e-09, -1.5940e-09,\n",
       "        -1.7837e-09, -4.4783e-09, -1.0627e-09, -2.7895e-09,  2.6566e-09,\n",
       "        -1.1196e-09,  2.3910e-09,  3.8066e-08, -1.3473e-09,  2.6566e-10,\n",
       "        -9.9149e-08,  1.6130e-09,  7.5904e-10, -2.1253e-09,  1.1386e-10,\n",
       "         1.4042e-09,  6.4518e-10, -1.2145e-09, -2.1336e-09, -2.6069e-07,\n",
       "         1.1082e-08,  3.6434e-09, -1.1006e-09,  2.3910e-09,  2.3530e-09,\n",
       "         1.8976e-09, -1.5181e-10,  5.4840e-09, -1.7458e-09, -1.8976e-10,\n",
       "         1.8976e-11,  3.0362e-10, -2.7515e-09, -1.4042e-09,  8.6151e-09,\n",
       "        -1.0057e-09,  2.8843e-09,  1.9735e-09,  1.9355e-09, -3.2259e-09,\n",
       "         1.3425e-09, -1.1765e-09,  5.1235e-10,  1.7268e-09,  1.7078e-09,\n",
       "        -1.4279e-09,  8.3494e-10, -5.4651e-09,  2.3910e-09,  1.3852e-09,\n",
       "         8.7289e-10,  2.9223e-09, -3.0741e-09, -8.9794e-08, -3.7952e-10,\n",
       "         4.5542e-10,  2.8843e-09,  1.5181e-09, -3.5106e-10,  5.8825e-10,\n",
       "        -2.8084e-09,  1.0911e-09,  3.4157e-10, -7.5904e-10,  1.6509e-09,\n",
       "         1.3141e-09, -6.4518e-10, -3.3777e-09, -2.8084e-09,  5.3133e-10,\n",
       "         4.9337e-10,  1.3283e-10,  1.2524e-09, -1.6035e-09, -1.3283e-10,\n",
       "         1.5925e-07, -2.0968e-09,  6.0723e-10,  1.2145e-09, -1.2524e-09,\n",
       "         4.9717e-09,  1.9261e-09,  5.4556e-10,  1.2524e-09,  7.4955e-09,\n",
       "         7.0211e-09, -4.2127e-09, -1.1670e-07,  1.2524e-09,  2.6566e-09,\n",
       "         2.2012e-09,  1.4422e-09,  1.6634e-10, -2.6187e-09,  2.4669e-10,\n",
       "         1.6889e-09, -2.3530e-09,  8.9187e-10, -5.1994e-09, -5.9205e-09,\n",
       "         3.5257e-08,  5.0856e-09,  2.8464e-09,  1.7458e-09, -9.1085e-09,\n",
       "         1.7078e-09, -9.4880e-10, -1.8976e-11, -3.9849e-10, -1.8596e-09,\n",
       "        -7.4006e-10,  9.6777e-10,  4.9337e-10, -5.6928e-10,  2.5048e-09,\n",
       "        -8.7289e-10,  2.3530e-09, -2.5048e-09,  1.1006e-09, -2.6187e-09,\n",
       "        -2.2771e-10, -4.5542e-09,  4.3645e-10, -8.9661e-09,  1.2691e-07,\n",
       "         3.0243e-11, -1.0674e-11, -1.8217e-09,  2.3720e-09,  1.3852e-09,\n",
       "        -8.2545e-10, -2.3720e-09, -1.5940e-09,  1.1442e-08, -6.6416e-10,\n",
       "         9.2982e-10,  9.8675e-10, -1.8976e-10,  9.4880e-10,  4.5542e-10,\n",
       "         1.4801e-09, -2.3765e-07,  2.0987e-08, -3.2259e-10,  1.1196e-09,\n",
       "        -2.4858e-09,  2.0115e-09, -2.2771e-10,  9.2641e-08,  1.5845e-09,\n",
       "        -1.5750e-09, -1.8976e-09, -6.8693e-09, -3.2069e-09, -7.5904e-10,\n",
       "        -1.8976e-10,  3.1880e-09, -1.5826e-08,  9.1085e-10,  1.0247e-09,\n",
       "         1.8596e-09,  1.1386e-09, -5.1615e-09,  1.1386e-10,  7.2109e-10,\n",
       "        -3.6624e-09,  4.7060e-09,  8.7764e-10,  7.9699e-10,  1.0721e-09,\n",
       "        -2.2771e-10, -4.1424e-08, -5.7649e-08, -1.7458e-09,  2.0494e-09,\n",
       "         1.0133e-08,  3.4726e-09, -4.7440e-11, -3.7667e-09,  5.9869e-09,\n",
       "        -3.7193e-09,  1.8027e-09,  6.0723e-10, -2.2012e-09,  2.9033e-09,\n",
       "        -1.0911e-08,  1.8976e-10,  2.6566e-10,  4.8389e-08, -2.3293e-09,\n",
       "         6.4518e-10,  2.0115e-09,  5.4081e-10,  1.5926e-08,  4.2506e-09,\n",
       "        -8.5392e-10,  2.5807e-09, -3.8711e-09, -2.0874e-09, -2.3815e-08,\n",
       "        -4.3645e-10,  1.1386e-10, -2.8084e-09, -1.6319e-09,  1.7837e-09,\n",
       "        -1.1613e-08,  6.8313e-10,  2.0874e-09, -1.4422e-09, -1.7078e-09,\n",
       "         6.4518e-10, -9.4880e-09, -7.7042e-09,  9.0705e-09, -2.4858e-09,\n",
       "         0.0000e+00, -4.0039e-09, -3.6054e-09,  1.0627e-09, -2.8084e-09,\n",
       "        -1.3283e-09,  9.1085e-10,  2.8464e-10, -2.5428e-09, -2.5428e-09,\n",
       "        -3.4157e-10,  7.9699e-10, -3.5295e-09, -1.8786e-09,  5.6928e-11,\n",
       "        -1.2714e-09,  4.3834e-09, -3.0362e-10, -1.0778e-08,  1.8217e-09,\n",
       "         1.0247e-09, -5.6928e-11, -1.1006e-09, -5.8446e-09,  4.0893e-09,\n",
       "        -1.2904e-09,  6.0723e-10,  1.4801e-09,  1.4422e-09,  2.4669e-09,\n",
       "        -8.1976e-09, -1.7173e-09,  4.5353e-09,  3.6149e-09, -1.2145e-09,\n",
       "         2.7325e-09,  1.5181e-10, -1.8596e-09,  1.8445e-08, -2.4669e-09,\n",
       "        -6.3190e-09, -4.2506e-09,  1.0152e-09,  2.5807e-09,  7.4006e-10,\n",
       "         1.0702e-08, -7.7422e-09, -3.3208e-09, -2.6756e-09,  6.6416e-10,\n",
       "        -8.7289e-10, -2.6377e-09, -5.1615e-09,  9.1085e-10,  8.9187e-09,\n",
       "        -1.9355e-09,  2.2012e-09,  8.7897e-08,  9.8675e-10, -3.5675e-09,\n",
       "        -2.2392e-09,  9.0610e-10,  6.3949e-09,  2.7136e-09, -2.9413e-09,\n",
       "         4.0153e-08, -7.5904e-10, -2.7420e-09, -2.0304e-09, -1.3663e-09,\n",
       "        -2.0836e-08, -1.5797e-09, -4.4404e-09, -4.5542e-09,  2.8464e-10,\n",
       "        -3.3777e-09, -2.3293e-09, -7.9699e-10, -7.9699e-10, -8.6720e-09,\n",
       "        -7.9699e-10, -6.0723e-10, -9.5829e-10, -3.6434e-09,  2.5807e-09,\n",
       "        -7.9509e-09,  3.1121e-09,  0.0000e+00,  9.6777e-10, -3.4992e-08,\n",
       "        -3.7952e-10,  1.8976e-11,  1.8217e-09, -3.7952e-11,  0.0000e+00,\n",
       "        -1.1613e-08,  4.1747e-10, -2.2961e-09, -1.9052e-08,  4.0039e-09,\n",
       "         1.5693e-08, -3.6813e-09, -5.5410e-09, -1.8976e-09, -1.0627e-09,\n",
       "         9.1085e-10,  1.9735e-09,  1.3283e-09, -9.6018e-09, -1.4294e-08,\n",
       "        -5.4840e-09, -7.9699e-10, -9.8675e-10,  6.4518e-10,  1.4327e-09,\n",
       "        -2.0115e-09,  3.1880e-09, -1.2809e-09, -3.8331e-09,  1.1765e-09,\n",
       "         5.8825e-10,  3.7003e-10, -1.8217e-09, -1.1955e-09,  7.5904e-10,\n",
       "         3.0362e-10,  1.7458e-09, -2.3151e-09,  2.4858e-09, -9.5829e-10,\n",
       "        -6.2621e-10,  1.5560e-09, -1.5181e-10,  3.1121e-08, -3.2259e-10,\n",
       "        -1.8976e-10,  7.0970e-09, -5.7687e-09, -6.4518e-10, -2.2041e-08,\n",
       "         7.0211e-10, -3.5390e-09,  1.8976e-09, -2.6946e-09,  4.0988e-09,\n",
       "        -1.7648e-09,  8.5392e-10, -1.4422e-09, -1.0930e-08,  1.8502e-10,\n",
       "        -2.5428e-09, -5.2962e-08,  1.8596e-09, -1.1718e-09, -1.5181e-09,\n",
       "         1.5181e-10,  6.2810e-09,  3.9849e-09,  6.4518e-09, -2.3364e-09,\n",
       "        -1.8596e-09,  8.3304e-09, -2.4669e-10, -1.1386e-09,  4.7440e-10,\n",
       "         1.0437e-09, -7.4006e-10, -1.0057e-09, -1.7458e-09,  3.6624e-09,\n",
       "        -1.0627e-09, -1.1765e-09,  2.6187e-09, -3.1121e-09, -7.5904e-11,\n",
       "         1.9128e-08, -1.0437e-08,  2.6946e-09,  1.2524e-09,  2.8464e-09,\n",
       "        -1.5892e-09, -4.4783e-09,  2.0494e-09,  1.3663e-09,  1.8217e-09,\n",
       "         1.3663e-09,  4.1747e-09,  6.4518e-10,  5.8019e-09,  2.8464e-09,\n",
       "        -1.9545e-09,  1.5181e-09,  2.9906e-08, -4.1368e-09, -6.2146e-10,\n",
       "         5.3133e-10, -2.1063e-09, -1.2239e-09], device='cuda:0',\n",
       "       grad_fn=<MeanBackward1>)"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W_U.mean(dim=-1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
